
# Module-level setup: imports interleaved with plotting configuration.
# The statements below run at import time and their order matters (the
# matplotlib style is applied first, then the seaborn theme on top of it).

# PATHS is used below as PATHS.dropbox (input data) and PATHS.figures (output figures).
from init import PATHS
import logging
LOGGER = logging.getLogger(__name__)
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
# NOTE(review): the bare 'seaborn' style name is reported as deprecated in recent
# matplotlib releases (renamed 'seaborn-v0_8') -- TODO confirm against the pinned version.
matplotlib.style.use('seaborn')
import seaborn as sns
sns.set(style="white")
# Shared palette; the plotting functions index it from the end (pal[-1], pal[-2], pal[-3]).
pal = sns.color_palette('colorblind')
from C_PatentVariables import conventions_names_colors
# dict_var_colors is read below as dict_var_colors[var][0] (legend label) and
# dict_var_colors[var][1] (line colour). The other two names are not used in the
# functions visible in this file section.
dict_subsector_shortnames, dict_var_colors, list_of_subsectors_in_cleancars = conventions_names_colors.main()
# reading_datasets provides read_supplier_firmlevel_panel(), used by the supplier plots.
from E_Analysis import reading_datasets
# linregress is used to estimate per-firm pre/post slopes in plot_activesuppliers_transvsnontrans.
from scipy.stats import linregress


def main():
    """Produce every transport vs. non-transport battery figure, in order.

    Outputs a series of graphs to understand the differences between
    transport and non-transport battery trends. A begin and an end marker
    are logged around the run; each plotting routine loads its own inputs
    and writes its own files.
    """
    LOGGER.info('Begin e_othergraphs_batteries_trans_vs_nontr.py')
    figure_builders = (
        plot_BatTrNonTr_familylevel,
        plot_BatTrNonTr_OEM,
        plot_BatTrNonTr_bysector,
        stackplot_BatTrNonTr_activesupplierpatenting_bytype,
        plot_activesuppliers_transvsnontrans,
    )
    for build_figure in figure_builders:
        build_figure()
    LOGGER.info('END e_othergraphs_batteries_trans_vs_nontr.py')


def plot_BatTrNonTr_familylevel():
    """Plot yearly transport vs. non-transport battery family counts and the transport share.

    Reads ``TimeSeries_FamilyCounts_in_ipc_cpc_transpo_from_all.csv`` from
    ``PATHS.dropbox`` and writes three PNG files to
    ``PATHS.figures / 'transport_nontransport'``:

    * ``BatTrNonTr_familylevel.png``       -- the two count series, linear y axis
    * ``BatTrNonTr_familylevel_log.png``   -- the same figure with a log y axis
    * ``BatTrNonTr_familylevel_share.png`` -- ``Count_BatTrans / Count_Bat`` per year

    Returns nothing.
    """
    dfYearLevel = pd.read_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/TimeSeries_FamilyCounts_in_ipc_cpc_transpo_from_all.csv')
    # Index by year so that plotting a column puts the filing year on the x axis.
    dfYearLevel.index = dfYearLevel['earliest_filing_year']
    # range(1970, 2018) keeps filing years 1970..2017 inclusive.
    maskYears = dfYearLevel['earliest_filing_year'].isin(range(1970, 2018))
    # Filter once and take an explicit copy: a column is added further down, and
    # assigning into a filtered slice triggers pandas' SettingWithCopyWarning.
    dfYearLevel = dfYearLevel[maskYears].copy()
    out_dir = PATHS.figures / 'transport_nontransport'

    # Non-Exclusive
    fig, ax = plt.subplots()
    for var in ['Count_BatTrans', 'Count_BatNotTr']:
        ax.plot(dfYearLevel[var], label=dict_var_colors[var][0], color=dict_var_colors[var][1], linewidth=3)
    ax.legend(ncol=1)
    ax.set_xlabel('Earliest Filing Year')
    ax.set_ylabel('Number of DocDB Families')
    ax.set_xticks(list(range(1970, 2018, 10)))
    fig.savefig(out_dir / 'BatTrNonTr_familylevel.png')
    # Same figure saved a second time with a logarithmic y axis.
    ax.set_yscale('log')
    fig.savefig(out_dir / 'BatTrNonTr_familylevel_log.png')
    plt.close(fig)

    # Share
    dfYearLevel['share_transInbat'] = dfYearLevel['Count_BatTrans'] / dfYearLevel['Count_Bat']
    fig, ax = plt.subplots()
    # Single unlabeled series: no legend is requested, since a legend call with
    # no labelled artist only produces a matplotlib warning.
    ax.plot(dfYearLevel['share_transInbat'], color=pal[-3], linewidth=3)
    ax.set_xlabel('Earliest Filing Year')
    ax.set_ylabel('Share of battery families related to transport')
    ax.set_ylim(0, 1)
    fig.savefig(out_dir / 'BatTrNonTr_familylevel_share.png')
    plt.close(fig)


def plot_BatTrNonTr_OEM():
    """Plot OEM (and subsidiary) battery / fuel-cell family counts and the transport share.

    Reads ``TimeSeries_FamilyCounts_in_ipc_cpc_transpo_from_OEM_and_Subsidiaries.csv``
    from ``PATHS.dropbox`` and writes, as both PNG and PDF, to
    ``PATHS.figures / 'transport_nontransport'``:

    * ``BatTrNonTr_OEM``       -- the three ``*_excl`` count series per filing year
    * ``BatTrNonTr_OEM_share`` -- ``Count_BatTrans / Count_Bat`` per filing year

    Returns nothing.
    """
    ts_counts = pd.read_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/TimeSeries_FamilyCounts_in_ipc_cpc_transpo_from_OEM_and_Subsidiaries.csv')
    # range(1990, 2016) keeps filing years 1990..2015 inclusive.
    maskYears = ts_counts['earliest_filing_year'].isin(range(1990, 2016))
    # Filter once and copy: a column is added below, and assigning into a
    # filtered slice triggers pandas' SettingWithCopyWarning.
    ts_counts = ts_counts[maskYears].copy()
    out_dir = PATHS.figures / 'transport_nontransport'

    # Counts
    fig, ax = plt.subplots()
    for var in ['Count_BatTrans_excl', 'Count_BatNotTr_excl', 'Count_FC_excl']:
        ax.plot(ts_counts['earliest_filing_year'], ts_counts[var], label=dict_var_colors[var][0], color=dict_var_colors[var][1], linewidth=3)
    ax.legend(ncol=1)
    ax.set_xlabel('Year')
    ax.set_ylabel('Number of Patent Families')
    for extension in ('png', 'pdf'):
        fig.savefig(out_dir / f'BatTrNonTr_OEM.{extension}', bbox_inches='tight')
    plt.close(fig)

    # Share
    # NOTE(review): the share uses the non-"_excl" columns while the count plot
    # above uses the "_excl" ones -- kept as in the original; confirm it is intended.
    ts_counts['share_transInbat'] = ts_counts['Count_BatTrans'] / ts_counts['Count_Bat']
    fig, ax = plt.subplots()
    # Single unlabeled series: no legend is requested, since a legend call with
    # no labelled artist only produces a matplotlib warning.
    ax.plot(ts_counts['earliest_filing_year'], ts_counts['share_transInbat'], color=pal[-3], linewidth=3)
    ax.set_xlabel('Earliest Filing Year')
    ax.set_ylabel('Share of battery families related to transport')
    ax.set_ylim(0, 1)
    for extension in ('png', 'pdf'):
        fig.savefig(out_dir / f'BatTrNonTr_OEM_share.{extension}', bbox_inches='tight')
    plt.close(fig)



def plot_BatTrNonTr_bysector():
    """Plot battery patenting by applicant sector, transport vs. non-transport, side by side.

    Reads ``Families_Bat_FC_naics_info.csv`` from ``PATHS.dropbox``. For filing
    years 1990..2015 and rows whose ``Sub-sector_exclusive`` is ``'batteries'``,
    draws two panels (left: ``BatTrans`` in {'Trans', 'NonTrans,Trans'};
    right: ``BatTrans == 'NonTrans'``). Each panel shows the yearly number of
    distinct family/firm pairs for four sector groups on the left axis and, on
    a secondary right axis, the motor-vehicle group as a percentage of all rows
    in the panel.

    Writes ``BatTrNonTr_Patenting_bysector`` as PDF and PNG to
    ``PATHS.figures / 'transport_nontransport'``. Returns nothing.
    """
    df_batfc_fam = pd.read_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/Families_Bat_FC_naics_info.csv')
    # Counting key is family id + firm id, so a family is counted once per firm.
    df_batfc_fam['docdb_bvd'] = df_batfc_fam['docdb_family_id'].astype(str) + '_' + df_batfc_fam['bvdid'].fillna('')
    maskyears = df_batfc_fam['earliest_filing_year'].isin(range(1990, 2016))
    maskBAT = df_batfc_fam['Sub-sector_exclusive'] == 'batteries'
    maskBATtrans = maskBAT & df_batfc_fam['BatTrans'].isin(['Trans', 'NonTrans,Trans'])
    maskBATnontrans = maskBAT & (df_batfc_fam['BatTrans'] == 'NonTrans')

    def sector_flag(column):
        """Return the rows whose indicator ``column`` is non-null and strictly positive."""
        return df_batfc_fam[column].notnull() & (df_batfc_fam[column] > 0)

    def families_per_year(rows):
        """Return the number of distinct family/firm pairs per filing year among ``rows``."""
        return df_batfc_fam[rows].groupby('earliest_filing_year')['docdb_bvd'].nunique()

    # Sector masks do not depend on the panel, so they are built once.
    # "Motor vehicle" means flagged as MotorVehicle OR as OEMorSubsidiary.
    mask_motor = sector_flag('MotorVehicle') | sector_flag('OEMorSubsidiary')
    # (row mask, legend label, colour, line width), in drawing / legend order.
    other_sectors = [
        (sector_flag('Electronics'), 'Electronics', pal[-1], 3),
        (sector_flag('MachineryChemical'), 'Machinery and Chemical Manufacturing', pal[-2], 2),
        (sector_flag('OtherTransport'), 'Other Transport', pal[-3], 2),
    ]

    pal2 = sns.color_palette('Greys')
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    right_axes = []
    for ax, mask in zip([ax1, ax2], [maskBATtrans, maskBATnontrans]):
        in_panel = mask & maskyears
        # MOTOR VEHICLE
        motor_counts = families_per_year(in_panel & mask_motor)
        ax.plot(motor_counts, label='Motor Vehicle', color=pal2[-1], linewidth=3)
        ax.set_ylim(0, 6500)
        # ALL SECTORS: denominator for the motor-vehicle percentage on the right axis
        total_counts = families_per_year(in_panel)
        ax_right = ax.twinx()
        right_axes.append(ax_right)
        ax_right.plot(100 * motor_counts / total_counts, label='Percent of Total from Motor Vehicle', color=pal2[-1], linewidth=2, linestyle='--')
        ax_right.set_ylim(0, 55)
        ax_right.set_ylabel('Percent from Motor Vehicle (%)')
        # ELECTRONICS, MACHINERY, OTHER TRANSPORT
        for sector_mask, sector_label, sector_color, sector_width in other_sectors:
            ax.plot(families_per_year(in_panel & sector_mask), label=sector_label, color=sector_color, linewidth=sector_width)
        ax.set_xlabel('Year')
        ax.set_ylabel('Number of Patent Families')

    # Both panels carry the same series, so one shared legend is built from the
    # last panel's axes (percentage line first, then the sector lines).
    h1, l1 = ax2.get_legend_handles_labels()
    h2, l2 = right_axes[-1].get_legend_handles_labels()
    plt.subplots_adjust(bottom=0.15, wspace=.4)
    fig.legend(h2 + h1, l2 + l1, loc='lower center', ncol=5, fontsize=8, fancybox=True, frameon=True, edgecolor='black', framealpha=.5, borderpad=.7)
    ax1.set_title('Transport Battery Patents')
    ax2.set_title('Non-Transport Battery Patents')
    plt.savefig(PATHS.figures / 'transport_nontransport' / 'BatTrNonTr_Patenting_bysector.pdf', bbox_inches='tight')
    plt.savefig(PATHS.figures / 'transport_nontransport' / 'BatTrNonTr_Patenting_bysector.png', bbox_inches='tight')
    plt.close()



def _save_supplier_stackplot(fuel_cells, total, total_label, total_color, x, layers, layer_labels, layer_colors, file_stem):
    """Draw one supplier figure (fuel-cell line, total line, stacked areas) and save it as PNG and PDF."""
    # Drawing order is lines first, then areas; the legend lists entries in that order.
    plt.plot(fuel_cells, label='Fuel Cells Patents', color=dict_var_colors['Count_FC_excl'][1], linewidth=5)
    plt.plot(total, label=total_label, color=total_color, linewidth=5)
    plt.stackplot(x, *layers, labels=layer_labels, colors=layer_colors, linewidth=0)
    plt.ylabel('Number of Patent Families')
    plt.legend(loc='upper left', ncol=1)
    for extension in ('png', 'pdf'):
        plt.savefig(PATHS.figures / 'transport_nontransport' / f'{file_stem}.{extension}', bbox_inches='tight')
    plt.close()


def stackplot_BatTrNonTr_activesupplierpatenting_bytype():
    """Plot battery patenting of active suppliers as stacked areas, in three figures.

    Uses the supplier firm-year panel from
    ``reading_datasets.read_supplier_firmlevel_panel()``. Only rows with
    ``Active == 1`` are used. Lines are yearly sums over filing years
    1990..2015; the supplier-type stacks cover filing years 2003..2015.

    Figures written (PNG and PDF) to ``PATHS.figures / 'transport_nontransport'``:

    * ``..._trans_vs_nontrans_battery_...`` -- battery total, stacked by
      transport / non-transport
    * ``..._transport_battery_...``         -- transport battery total, stacked
      by supplier type
    * ``..._nontransport_battery_...``      -- non-transport battery total,
      stacked by supplier type

    Supplier types: "automotive" (``4digitNAICS`` matches 3361, 3362 or 3363),
    and, outside that sector, ``oldguard == 1`` vs. ``oldguard == 0``.
    Summing across firms assumes minimal double patenting between firms.

    Returns nothing.
    """
    dfsupplieryear = reading_datasets.read_supplier_firmlevel_panel()
    filing_year = dfsupplieryear['earliest_filing_year']
    is_active = dfsupplieryear['Active'] == 1

    # Yearly totals for the lines. Only the plotted columns are summed, so the
    # result does not depend on how the installed pandas version treats the
    # non-numeric columns of the panel in a groupby sum.
    beg_year, end_year = 1990, 2016
    line_cols = ['Count_FC_excl', 'Count_Bat_excl', 'Count_BatTrans_excl', 'Count_BatNotTr_excl']
    in_line_span = is_active & (filing_year >= beg_year) & (filing_year < end_year)
    data = dfsupplieryear[in_line_span].groupby('earliest_filing_year')[line_cols].sum()

    # Yearly totals per supplier type for the stacked areas (2003..2015).
    active_in_stack_span = is_active & (filing_year >= 2003) & (filing_year < 2016)
    # na=False makes missing NAICS codes count as "not automotive" explicitly.
    in_auto_sector = dfsupplieryear['4digitNAICS'].str.contains('3361|3362|3363', regex=True, na=False) & active_in_stack_span
    outside_auto = (~in_auto_sector) & active_in_stack_span
    group_masks = {
        'automotive': in_auto_sector,
        'oldguard': (dfsupplieryear['oldguard'] == 1) & outside_auto,
        'newguard': (dfsupplieryear['oldguard'] == 0) & outside_auto,
    }
    cols = ['Count_Bat_excl', 'Count_BatTrans_excl', 'Count_BatNotTr_excl']
    supp_ts_bygroup = dfsupplieryear[active_in_stack_span].groupby(['earliest_filing_year'])[cols].sum().reset_index()
    for prefix, group_mask in group_masks.items():
        group_ts = dfsupplieryear[group_mask].groupby(['earliest_filing_year'])[cols].sum().reset_index().rename(columns={c: f'{prefix}_{c}' for c in cols})
        supp_ts_bygroup = supp_ts_bygroup.merge(group_ts, on='earliest_filing_year', how='left')
    # A supplier type with no rows in a given year contributes zero to the stack.
    supp_ts_bygroup = supp_ts_bygroup.fillna(0)

    # PLOT FIGURES
    pal2 = sns.color_palette('Blues')
    type_labels = ['  from suppliers in Motor Vehicles', '  from outside Motor Vehicles - Old', '  from outside Motor Vehicles - New']
    type_colors = [pal2[-6], pal2[-4], pal2[-2]]

    # TRANSPORT VS NON TRANSPORT BATTERY PATENTS
    _save_supplier_stackplot(
        fuel_cells=data['Count_FC_excl'],
        total=data['Count_Bat_excl'],
        total_label='Battery Patents:',
        total_color=dict_var_colors['Count_Bat_excl'][1],
        x=data.index,
        layers=[data['Count_BatTrans_excl'], data['Count_BatNotTr_excl']],
        layer_labels=['  Transport related', '  Non-transport'],
        layer_colors=[pal2[-6], pal2[-4]],
        file_stem='activesupplierspatenting_trans_vs_nontrans_battery_stackplot_bytypeofsuppliers',
    )
    # TRANSPORT BATTERY PATENTS
    _save_supplier_stackplot(
        fuel_cells=data['Count_FC_excl'],
        total=data['Count_BatTrans_excl'],
        total_label='Transport Battery Patents:',
        total_color=dict_var_colors['Count_BatTrans_excl'][1],
        x=supp_ts_bygroup['earliest_filing_year'],
        layers=[supp_ts_bygroup[f'{prefix}_Count_BatTrans_excl'] for prefix in group_masks],
        layer_labels=type_labels,
        layer_colors=type_colors,
        file_stem='activesupplierspatenting_transport_battery_stackplot_bytypeofsuppliers',
    )
    # NON TRANSPORT BATTERY PATENTS
    _save_supplier_stackplot(
        fuel_cells=data['Count_FC_excl'],
        total=data['Count_BatNotTr_excl'],
        total_label='Non-Transport Battery Patents:',
        total_color=dict_var_colors['Count_BatNotTr_excl'][1],
        x=supp_ts_bygroup['earliest_filing_year'],
        layers=[supp_ts_bygroup[f'{prefix}_Count_BatNotTr_excl'] for prefix in group_masks],
        layer_labels=type_labels,
        layer_colors=type_colors,
        file_stem='activesupplierspatenting_nontransport_battery_stackplot_bytypeofsuppliers',
    )


def plot_activesuppliers_transvsnontrans():
    """Plot battery patenting of "new" suppliers around their first active year.

    Uses the supplier firm-year panel from
    ``reading_datasets.read_supplier_firmlevel_panel()``.

    Sample: firms that have at least one row in filing years 2003..2015 with
    ``Active == 1``, ``oldguard == 0`` and a ``4digitNAICS`` that does not match
    3361/3362/3363. Their rows for 1990..2015 are kept, and ``RelativeYear`` is
    the filing year minus the firm's first year with ``Active == 1``. Only firms
    observed in all seven relative years -3..3 (balanced panel) are plotted.

    Figures written (PNG and PDF) to ``PATHS.figures / 'transport_nontransport'``:

    * ``activesuppliers_trans_vs_nontrans_battery`` -- mean transport and
      non-transport battery counts by relative year
    * ``activesuppliers_scatter_slopes_beforeafterconnecting_{BatTrans,BatNotTr}``
      -- per-firm linear trend before (relative years -3..0) against after
      (relative years 0..3)

    Returns nothing.
    """
    dfsupplieryear = reading_datasets.read_supplier_firmlevel_panel()
    # select those that are at some point active in 2003..2015
    maskYears = (dfsupplieryear['earliest_filing_year'] >= 2003) & (dfsupplieryear['earliest_filing_year'] < 2016)
    maskActive = (dfsupplieryear['Active'] == 1) & maskYears
    # and those outside motor vehicles (na=False: a missing NAICS code is not a match)
    autosectorMask = dfsupplieryear['4digitNAICS'].str.contains('3361|3362|3363', regex=True, na=False) & maskActive
    newguardMask = (dfsupplieryear['oldguard'] == 0) & (~autosectorMask) & maskActive
    bvdids_of_newguardsuppliers = dfsupplieryear.loc[newguardMask, 'bvdid'].drop_duplicates().to_list()
    # Now get panel for this subset of suppliers
    dfsupplieryear = dfsupplieryear[dfsupplieryear['bvdid'].isin(bvdids_of_newguardsuppliers)]
    # limit time span
    beg_year, end_year = 1990, 2016
    maskYears = (dfsupplieryear['earliest_filing_year'] >= beg_year) & (dfsupplieryear['earliest_filing_year'] < end_year)
    dfsupplieryear = dfsupplieryear[maskYears]
    # Find year when they first get connected
    year_of_connection = dfsupplieryear[dfsupplieryear['Active'] == 1].groupby('bvdid')['earliest_filing_year'].min().rename('YearConnect')
    dfsupplieryear = dfsupplieryear.merge(year_of_connection, on='bvdid', how='left')
    dfsupplieryear['RelativeYear'] = dfsupplieryear['earliest_filing_year'] - dfsupplieryear['YearConnect']

    # look at a few years before and after connection
    window_years = list(range(-3, 4))
    firms_in_window = dfsupplieryear[dfsupplieryear['RelativeYear'].isin(window_years)].groupby('bvdid')
    supliertotalcount = firms_in_window['Count_Bat_excl'].sum().rename('totalcount')
    # number of distinct relative years observed per firm: indicator for the balanced panel
    NbrYears = firms_in_window['RelativeYear'].nunique().rename('Nyears')
    dfsupplieryear = dfsupplieryear.merge(supliertotalcount, on='bvdid', how='left')
    # Row order (by total battery count in the window) fixes the firm order used below.
    dfsupplieryear = dfsupplieryear.sort_values(by='totalcount')
    dfsupplieryear = dfsupplieryear.merge(NbrYears, on='bvdid', how='left')

    # Balanced panel inside the window: firms seen in all 7 relative years.
    maskBalanced = dfsupplieryear['Nyears'] == len(window_years)
    maskTimespan = dfsupplieryear['RelativeYear'].isin(window_years)
    balanced_window = dfsupplieryear[maskTimespan & maskBalanced]
    out_dir = PATHS.figures / 'transport_nontransport'

    # PLOT
    # Average only the plotted columns: a mean over the full frame would also
    # hit the text columns of the panel, which recent pandas no longer skips.
    plot_cols = ['Count_BatTrans_excl', 'Count_BatNotTr_excl']
    data = balanced_window.groupby('RelativeYear')[plot_cols].mean()
    fig, ax = plt.subplots()
    ax.plot(data['Count_BatTrans_excl'], label='Transport Battery', color=dict_var_colors['Count_BatTrans_excl'][1], linewidth=3)
    ax.plot(data['Count_BatNotTr_excl'], label='Non-Transport Battery', color=dict_var_colors['Count_BatNotTr_excl'][1], linewidth=3)
    ax.legend(ncol=1)
    ax.set_xlabel('Years after first connection to an OEM')
    ax.set_ylabel('Mean Patent Count of "New" Suppliers')
    ax.set_ylim(0, 25)
    for extension in ('png', 'pdf'):
        fig.savefig(out_dir / f'activesuppliers_trans_vs_nontrans_battery.{extension}', bbox_inches='tight')
    plt.close(fig)

    # OUTPUT SCATTERPLOT
    for variable in ['BatTrans', 'BatNotTr']:
        count_col = f'Count_{variable}_excl'
        all_rows = []
        # One pass over the firms instead of re-masking the whole panel per firm;
        # sort=False keeps the firms in their order of first appearance.
        for bvdidsupplier, firm_rows in balanced_window.groupby('bvdid', sort=False):
            # Relative year 0 belongs to both the pre- and the post-period.
            pre = firm_rows[firm_rows['RelativeYear'] <= 0]
            post = firm_rows[firm_rows['RelativeYear'] >= 0]
            slopePre = linregress(pre['RelativeYear'], pre[count_col]).slope
            slopePost = linregress(post['RelativeYear'], post[count_col]).slope
            all_rows.append({'bvdid': bvdidsupplier, 'PreSlope': slopePre, 'PostSlope': slopePost})
        # Explicit columns so an empty firm set yields an empty plot rather than a KeyError.
        df_slopes = pd.DataFrame(all_rows, columns=['bvdid', 'PreSlope', 'PostSlope'])
        fig, ax = plt.subplots()
        ax.scatter(df_slopes['PreSlope'], df_slopes['PostSlope'], zorder=3)
        ax.set_xlabel('Slope Before Connecting')
        ax.set_ylabel('Slope After Connecting')
        ax.axhline(0, color='lightgray', zorder=1)  # Add horizontal line at y=0
        ax.axvline(0, color='lightgray', zorder=1)  # Add vertical line at x=0
        for extension in ('png', 'pdf'):
            fig.savefig(out_dir / f'activesuppliers_scatter_slopes_beforeafterconnecting_{variable}.{extension}', bbox_inches='tight')
        plt.close(fig)


